# -*- coding: utf-8 -*-
'''
Created on 2017年4月24日

@author: ZhuJiahui506
'''

import os
import time
import numpy as np
from file_utils.file_reader import read_to_1d_list, read_to_2d_list
from file_utils.file_writer import quick_write_1d_to_text
from topic_utils.distribution_util import get_real_topics, sparse_topic_vsm

def update_topic(pl):
    '''
    Draw a new topic index by sampling from the (unnormalized)
    distribution P(Zi|Z-i,W).

    :param pl: 1-d array-like of unnormalized probabilities P(Zi|Z-i,W);
               entries may be negative (the GPU counts are decremented by
               fractional miu amounts), in which case the whole vector is
               shifted to be positive before normalization.
    :return: sampled topic index (int in [0, len(pl) - 1])
    '''
    pl = np.asarray(pl, dtype=float)
    # Shift all mass to be strictly positive if any entry went negative.
    if np.min(pl) < 0:
        pl = pl + np.abs(np.min(pl)) + np.true_divide(1.0, len(pl))
    norm_pl = np.true_divide(pl, np.sum(pl))

    # Inverse-CDF sampling: the first index whose cumulative probability
    # exceeds r.  The original loop started at index -1, which subtracted
    # norm_pl[-1] first and returned an index shifted by one (topic i was
    # drawn with the probability of topic i-1).
    r = np.random.rand()
    index = int(np.searchsorted(np.cumsum(norm_pl), r, side='right'))
    # Guard against float round-off pushing r past the last cumsum bin.
    return min(index, len(norm_pl) - 1)


def get_selected_embeddings(read_filename, selected_word_list, embedding_dis):
    '''
    Load the embedding vectors of a given word list from a Chinese
    word-embedding text file.

    Each file line is expected to be: word v1 v2 ... vN.  Words that do
    not appear in the file keep an all-zero row.

    :param read_filename: path of the embedding file (utf-8 text)
    :param selected_word_list: 1-d list of words to look up
    :param embedding_dis: embedding dimensionality
    :return: numpy 2d array, shape (len(selected_word_list), embedding_dis)
    '''

    # initialize the result matrix (missing words stay all-zero)
    selected_embeddings = np.zeros((len(selected_word_list), embedding_dis), dtype=float)

    # O(1) dict lookup instead of list.index() (O(n)) on every file line.
    # setdefault keeps the FIRST occurrence for duplicate words, matching
    # the original list.index behavior.
    word_to_row = {}
    for row, word in enumerate(selected_word_list):
        word_to_row.setdefault(word, row)

    with open(read_filename, 'r', encoding='utf-8') as f:
        for each_line in f:
            split_line = each_line.strip().split()
            if not split_line:
                continue  # skip blank lines
            row = word_to_row.get(split_line[0])
            if row is not None:
                selected_embeddings[row, :] = [float(x) for x in split_line[1:]]

    return selected_embeddings

def get_domain_knowledge(selected_embeddings, sim_word_num):
    '''
    Build domain knowledge from word embeddings.

    For every word, keep the sim_word_num indices with the highest cosine
    similarity (the word itself is among them, since self-similarity is 1).
    A word whose similarity row sums to ~1 is only similar to itself
    (e.g. its embedding row is all zeros) and keeps just its own index.

    :param selected_embeddings: 2-d array of word embedding vectors
    :param sim_word_num: number of similar words to keep per word
    :return: 2-d int list; row i holds the indices similar to word i
    '''
    n_words = len(selected_embeddings)
    norms = np.linalg.norm(selected_embeddings, 2, 1)

    # Symmetric cosine-similarity matrix; pairs involving a (near-)zero
    # norm vector keep similarity 0.
    cosine = np.zeros((n_words, n_words))
    for row in range(n_words):
        cosine[row][row] = 1
        for col in range(row + 1, n_words):
            if norms[row] > 0.000001 and norms[col] > 0.000001:
                value = np.true_divide(
                    np.dot(selected_embeddings[row], selected_embeddings[col]),
                    (norms[row] * norms[col]))
                cosine[row][col] = value
                cosine[col][row] = value

    knowledge = []
    for row in range(n_words):
        row_total = np.sum(cosine[row])
        if 0.99999 < row_total < 1.00001:
            # Similar only to itself: the word has no usable embedding.
            knowledge.append([row])
        else:
            ranked = np.argsort(cosine[row])
            top = ranked[(n_words - sim_word_num):]
            knowledge.append(list(top))

    return knowledge


def gpu_lda(doc_corpus, vocabulary, K, domain_knowledge, miu, dir_alpha=0.1, dir_beta=0.01, max_iter=50):
    '''
    Collapsed Gibbs sampling LDA augmented with a Generalized Polya Urn
    (GPU) scheme: whenever a word is (un)assigned to a topic, the words
    listed as semantically similar to it in domain_knowledge also have
    their counts in that topic adjusted by the fractional weight miu.

    :param doc_corpus: document collection, 2-d list whose elements are
                       word ids encoded as str
    :param vocabulary: word list (its length defines the vocabulary size)
    :param K: number of topics
    :param domain_knowledge: 2-d int list; row w holds the ids of words
                             semantically close to word w (may include w;
                             w itself is skipped during reinforcement)
    :param miu: GPU reinforcement weight applied to similar words
    :param dir_alpha: symmetric Dirichlet prior for document-topic
                      (same value on every dimension)
    :param dir_beta: symmetric Dirichlet prior for topic-word
                     (same value on every dimension)
    :param max_iter: number of Gibbs sampling iterations
    :return: Theta (doc-topic) and Phai (topic-word), both row-normalized
             numpy 2d arrays
    '''
    
    doc_num = len(doc_corpus)  # number of documents
    word_num = len(vocabulary)  # vocabulary size
    
    doc_topic_count = np.zeros((doc_num, K))  # document-topic co-occurrence counts
    topic_word_count = np.zeros((K, word_num))  # topic-word co-occurrence counts
    topic_count = np.zeros(K)  # total mass assigned to each topic (fractional because of miu)
    topic_assignment = []  # topic assignment sequences for the whole corpus

    # Initialize the topic assignments.
    # NOTE(review): every word of a document gets the SAME random initial
    # topic; the usual per-word draw is commented out below -- confirm intended.
    for i in range(doc_num):
        line_topic_assignment = []  # topic assignment sequence of the current document
        initial_topic_assignment = np.random.randint(K)
        
        for word in doc_corpus[i]:
            # initial_topic_assignment = np.random.randint(K)
            line_topic_assignment.append(initial_topic_assignment)
            
            # increment the corresponding counters
            doc_topic_count[i, initial_topic_assignment] += 1
            topic_word_count[initial_topic_assignment, int(word)] += 1
            topic_count[initial_topic_assignment] += 1
            
            # GPU step: also promote the words similar to this one by miu
            for word_pie in domain_knowledge[int(word)]:
                if word_pie == int(word):
                    continue

                doc_topic_count[i, initial_topic_assignment] += miu
                topic_word_count[initial_topic_assignment, word_pie] += miu
                topic_count[initial_topic_assignment] += miu

        topic_assignment.append(line_topic_assignment)

    # GPU-augmented Gibbs sampling
    for t in range(max_iter):
        for i in range(doc_num):
            last_topic_assignment = -1
            for j in range(len(doc_corpus[i])):
                
                current_topic = topic_assignment[i][j]
                current_word = int(doc_corpus[i][j])
                
                # decrement the counters (withdraw the current assignment)
                doc_topic_count[i, current_topic] -= 1
                topic_word_count[current_topic, current_word] -= 1
                topic_count[current_topic] -= 1

                # GPU step: withdraw the mass of the similar words too
                for word_pie in domain_knowledge[current_word]:
                    if word_pie == current_word:
                        continue

                    doc_topic_count[i, current_topic] -= miu
                    topic_word_count[current_topic, word_pie] -= miu
                    topic_count[current_topic] -= miu
   
                # compute P(Zi|Z-i,W)
                P_Z_W = np.true_divide((doc_topic_count[i] + dir_alpha) * (topic_word_count[:, current_word] + dir_beta), 
                                       (topic_count + word_num * dir_beta))
                
                # NOTE(review): fixed bonus for the topic of the previous
                # word in this document -- non-standard LDA, presumably to
                # encourage locally coherent topics; confirm intended.
                if last_topic_assignment != -1:
                    P_Z_W[last_topic_assignment] += 0.01

                # resample the current word's topic from P(Zi|Z-i,W)
                new_topic = update_topic(P_Z_W)
                topic_assignment[i][j] = new_topic
                last_topic_assignment = new_topic
                
                # increment the counters for the new assignment
                doc_topic_count[i, new_topic] += 1
                topic_word_count[new_topic, current_word] += 1
                topic_count[new_topic] += 1
                for word_pie in domain_knowledge[current_word]:
                    if word_pie == current_word:
                        continue

                    doc_topic_count[i, new_topic] += miu
                    topic_word_count[new_topic, word_pie] += miu
                    topic_count[new_topic] += miu

    # document-topic distribution and topic-word distribution
    Theta = np.zeros((doc_num, K))
    Phai = np.zeros((K, word_num))
    
    doc_row_sum = np.sum(doc_topic_count, 1)
    topic_row_sum = np.sum(topic_word_count, 1)
    
    for i in range(doc_num):
        Theta[i, :] = np.true_divide(doc_topic_count[i, :] + dir_alpha, doc_row_sum[i] + K * dir_alpha)
    for i in range(K):
        Phai[i, :] = np.true_divide(topic_word_count[i, :] + dir_beta, topic_row_sum[i] + word_num * dir_beta)
    
    doc_row_sum = np.sum(Theta, 1)
    topic_row_sum = np.sum(Phai, 1)
    # normalize each row so the distributions sum exactly to 1
    for i in range(doc_num):
        Theta[i, :] = np.true_divide(Theta[i, :], doc_row_sum[i])
    for i in range(K):
        Phai[i, :] = np.true_divide(Phai[i, :], topic_row_sum[i])
    
    return Theta, Phai


def dk_test():
    '''
    Build domain-knowledge files for corpora 60..69: for each corpus,
    load its vocabulary, look up the word embeddings, compute the most
    similar words per word and write them as one space-separated line of
    indices per word into dataset/DK_set/<i>.txt.
    '''
    now_directory = os.getcwd()
    root_directory = os.path.dirname(now_directory) + '/'
    read_directory = root_directory + 'dataset/text_model/tm_corpus/'
    write_directory = root_directory + 'dataset/DK_set'
    readvec_filename = root_directory + 'dataset/zhwiki_word_embedding.txt'

    # makedirs(exist_ok=True) also creates missing parents and does not
    # fail when the directory already exists.
    os.makedirs(write_directory, exist_ok=True)

    sim_word_num = 4

    for i in range(60, 70):
        # time.clock() was removed in Python 3.8; perf_counter() is the
        # recommended replacement for elapsed-time measurement.
        start_time = time.perf_counter()
        # vocab lines look like "id:word"; keep only the word part
        word_dictionary = [each.split(':')[1] for each in read_to_1d_list(read_directory + str(i) + '/' + str(i) + '.vocab')]
        selected_embeddings = get_selected_embeddings(readvec_filename, word_dictionary, 200)
        domain_knowledge = get_domain_knowledge(selected_embeddings, sim_word_num)
        # serialize each knowledge row as one space-separated text line
        domain_knowledge_string = [' '.join(str(x) for x in each) for each in domain_knowledge]
        quick_write_1d_to_text(write_directory + '/' + str(i) + '.txt', domain_knowledge_string)

        this_time = time.perf_counter() - start_time
        print(this_time)

    print('create knowledge complete!')

def gpu_lda_test():
    '''
    Run GPU-LDA over corpora 60..69 and persist, per corpus, the
    document-topic matrix (Theta), the sparsified topic-word matrix
    (Phai) and the human-readable representation of each topic.
    '''
    latent_topic_number = 40

    now_directory = os.getcwd()
    root_directory = os.path.dirname(now_directory) + '/'
    read_directory = root_directory + 'dataset/text_model/tm_corpus/'
    read_filename = root_directory + 'dataset/zhwiki_word_embedding.txt'
    write_directory = root_directory + 'dataset/gpu_lda'
    write_directory1 = write_directory + '/feed_topic' + str(latent_topic_number)
    write_directory2 = write_directory + '/topic_word' + str(latent_topic_number)
    write_directory3 = write_directory + '/real_topic' + str(latent_topic_number)

    # makedirs(exist_ok=True) creates missing parents too and replaces
    # the four separate exists()/mkdir() pairs.
    for directory in (write_directory, write_directory1, write_directory2, write_directory3):
        os.makedirs(directory, exist_ok=True)

    sim_word_num = 4

    for i in range(60, 70):
        # time.clock() was removed in Python 3.8; use perf_counter() instead.
        start_time = time.perf_counter()
        doc_corpus = read_to_2d_list(read_directory + str(i) + '/' + str(i) + '.docs', ' ')
        # vocab lines look like "id:word"; keep only the word part
        vocabulary = [each.split(':')[1] for each in read_to_1d_list(read_directory + str(i) + '/' + str(i) + '.vocab')]
        selected_embeddings = get_selected_embeddings(read_filename, vocabulary, 200)
        dk_set = get_domain_knowledge(selected_embeddings, sim_word_num)

        THETA, PHAI = gpu_lda(doc_corpus, vocabulary, latent_topic_number, dk_set, 0.1,
                              dir_alpha=(50.0 / latent_topic_number), dir_beta=0.01, max_iter=100)
        # THETA = sparse_topic_vsm(THETA, 5)
        PHAI = sparse_topic_vsm(PHAI, 10)
        real_topics = get_real_topics(PHAI, vocabulary)

        # serialize each matrix row as one space-separated text line
        PHAI_to_string = [" ".join(str(x) for x in row) for row in PHAI]
        THETA_to_string = [" ".join(str(x) for x in row) for row in THETA]

        quick_write_1d_to_text(write_directory1 + '/' + str(i) + '.txt', THETA_to_string)
        quick_write_1d_to_text(write_directory2 + '/' + str(i) + '.txt', PHAI_to_string)
        quick_write_1d_to_text(write_directory3 + '/' + str(i) + '.txt', real_topics)

        this_time = time.perf_counter() - start_time
        print(this_time)


if __name__ == '__main__':
    # Script entry point: dk_test() writes the precomputed domain-knowledge
    # files; gpu_lda_test() runs the full GPU-LDA pipeline.
    # dk_test()
    gpu_lda_test()